from utils import create_one_dim_inf_dim_market, solve_one_dim_inf_dim_market
from json import load
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
import pandas as pd
import seaborn as sns

# Plot styling. NOTE(review): `font` is defined but never applied anywhere in
# this script (e.g. via matplotlib.rc('font', **font)) — confirm whether the
# 24pt font size was meant to take effect.
font = dict(size=24)
sns.set_theme()

# Reference market and its exact equilibrium. The market is built with a
# fixed seed (2022) and normalization; `n` is its size parameter (presumably
# the number of buyers — confirm against utils.create_one_dim_inf_dim_market).
n = 50
B, c, d = create_one_dim_inf_dim_market(n, sd=2022, normalize=True)
u_eq_true = solve_one_dim_inf_dim_market(B, c, d)

# Ground-truth Nash social welfare: sum of B-weighted log utilities at the
# solved equilibrium (drawn as the NSW* reference line in the plot).
weighted_log_utils = B * np.log(u_eq_true)
nsw_true = np.sum(weighted_log_utils)

# Load the NSW trajectories produced by the seeded runs (seeds 1..10).
# Each archive holds an 'nsw_list' trajectory; the shared time grid
# 't_list' is read once, from the first seed's archive.
# Result: `loaded` is an array with one row per seed, and `t_list` is the
# x-axis grid.
loaded = []
t_list = None
for sd in range(1, 11):
    # Build the path from `n` rather than a literal 50, so the runs loaded here
    # always belong to the same market size as the ground truth above.
    with np.load(f'results/nsw_list_n_{n}_one_dim_sd_{sd}.npz') as data:
        if t_list is None:
            t_list = data['t_list']
        nsw_list = data['nsw_list']
    # Every trajectory is plotted against t_list, so the lengths must agree;
    # fail here with a clear message instead of deep inside numpy/matplotlib.
    if len(nsw_list) != len(t_list):
        raise ValueError(
            f'seed {sd}: nsw_list has {len(nsw_list)} entries, '
            f'expected {len(t_list)} to match t_list'
        )
    loaded.append(nsw_list)
loaded = np.array(loaded)

# Summarize the runs across seeds (rows of `loaded`):
#   - mean NSW trajectory with standard-error bars,
#   - the ground-truth NSW* as a constant reference line,
#   - a shaded band spanning the min/max over seeds.
n_runs = loaded.shape[0]
mean_arr = loaded.mean(axis=0)
# Standard error of the mean = sample std (ddof=1) / sqrt(number of runs).
# With a single run the SE is undefined (ddof=1 would give NaN), so draw
# zero-width error bars in that case.
if n_runs > 1:
    se_arr = loaded.std(axis=0, ddof=1) / np.sqrt(n_runs)
else:
    se_arr = np.zeros_like(mean_arr)
plt.errorbar(t_list, mean_arr, se_arr, label=r'NSW$^{\gamma}$')
# NSW at the exact equilibrium does not depend on t: draw it as a flat line.
plt.plot(t_list, [nsw_true] * len(t_list), label=r'NSW$^*$')
plt.fill_between(t_list, loaded.min(axis=0), loaded.max(axis=0), alpha=0.5)
plt.xlabel(r'$t$')
plt.ylabel(r'NSW')
plt.legend()
# plt.title(r'$v_i(\theta) = \alpha_i^\top \theta + c_i$')
plt.savefig('plots/nsw_one_dim_increasing_t.pdf', bbox_inches='tight')